from functools import partial

import torch
from torch import nn
import torch.distributed as dist
from torch.distributed._tensor import Replicate, distribute_tensor, Shard
from torch.distributed.tensor.parallel import (ColwiseParallel,
                                               PrepareModuleInput,
                                               PrepareModuleOutput,
                                               RowwiseParallel,
                                               parallelize_module)

from xtuner._lite import get_logger
from xtuner._lite.parallel.megatron.utils import map_rank0_modules
from xtuner._lite.parallel.fsdp import lazy_init_megatron

logger = get_logger()


def megatron_internlm3_moe(
        model,
        rank0_model,
        experts_fsdp_mesh,
        ep_mesh,
        mp_policy=None,
        recompute_ratio=1.0,
        reshard_after_forward=True):
    """Initialize and FSDP-wrap the layers of an InternLM3 MoE backbone.

    Each entry of ``model.layers`` is, in this order, passed through
    ``lazy_init_megatron`` (via ``.apply``), wrapped with ``fully_shard`` on
    ``experts_fsdp_mesh`` and, for the leading layers only, wrapped with the
    composable ``checkpoint``. Afterwards every layer is told to forward
    prefetch its successor, and the remaining top-level modules
    (``embed_tokens``, ``norm`` and, when present, ``rotary_emb``) are
    initialized with the same init function. The return values of the
    wrapping calls are discarded and the function itself returns ``None``.

    Args:
        model: Object exposing ``layers``, ``embed_tokens``, ``norm`` and
            optionally ``rotary_emb``.
        rank0_model: Counterpart handed to ``map_rank0_modules`` together
            with ``model``. Only read when ``experts_fsdp_mesh.get_rank()``
            is 0.
        experts_fsdp_mesh: Mesh used for the rank check, forwarded as
            ``dp_mesh`` to ``lazy_init_megatron`` and as ``mesh`` to
            ``fully_shard``.
        ep_mesh: Accepted for interface compatibility; not used here.
        mp_policy: Forwarded unchanged to ``fully_shard``.
        recompute_ratio: The first ``int(num_layers * recompute_ratio)``
            layers are wrapped with ``checkpoint``.
        reshard_after_forward: Forwarded unchanged to ``fully_shard``.
    """
    # Only rank 0 builds the module mapping; every other rank passes None.
    on_rank0 = experts_fsdp_mesh.get_rank() == 0
    rank0_map = map_rank0_modules(model, rank0_model) if on_rank0 else None

    init_fn = partial(
        lazy_init_megatron, rank0_map=rank0_map, dp_mesh=experts_fsdp_mesh)

    from torch.distributed._composable import checkpoint
    from torch.distributed._composable.fsdp import fully_shard

    # Number of leading layers that get the checkpoint wrapper.
    recompute_until = int(len(model.layers) * recompute_ratio)

    for idx, layer in enumerate(model.layers):
        # Parameters are initialized before the layer is sharded.
        layer.apply(init_fn)
        fully_shard(
            layer,
            mesh=experts_fsdp_mesh,
            mp_policy=mp_policy,
            reshard_after_forward=reshard_after_forward)

        if idx >= recompute_until:
            continue
        checkpoint(layer)

    # zip() stops at the shorter input, so the last layer gets no successor.
    for current, following in zip(model.layers, model.layers[1:]):
        current.set_modules_to_forward_prefetch([following])

    # Modules outside the layer stack are initialized but not wrapped here.
    for attr in ('embed_tokens', 'norm'):
        getattr(model, attr).apply(init_fn)
    if hasattr(model, 'rotary_emb'):
        model.rotary_emb.apply(init_fn)


def megatron_internlm3_moe_casual(
        model,
        rank0_model,
        experts_fsdp_mesh,
        ep_mesh,
        mp_policy=None,
        recompute_ratio=1.0,
        reshard_after_forward=True):
    """Initialize and FSDP-wrap a model made of a backbone plus ``lm_head``.

    The backbone ``model.model`` is handled by ``megatron_internlm3_moe``
    first. Then ``model.lm_head`` is passed through ``lazy_init_megatron``
    (via ``.apply``), the whole ``model`` is wrapped with ``fully_shard``
    on ``experts_fsdp_mesh``, and the wrapped root is told to forward
    prefetch the first backbone layer. Returns ``None``.

    Args:
        model: Object exposing ``model`` (the backbone, which in turn has
            ``layers``) and ``lm_head``.
        rank0_model: Counterpart of ``model``. Its ``model`` attribute and
            the object itself are only read when
            ``experts_fsdp_mesh.get_rank()`` is 0.
        experts_fsdp_mesh: Mesh used for the rank check, forwarded as
            ``dp_mesh`` to ``lazy_init_megatron`` and as ``mesh`` to
            ``fully_shard``.
        ep_mesh: Forwarded to ``megatron_internlm3_moe``; not otherwise
            used here.
        mp_policy: Forwarded to the backbone setup and to the root
            ``fully_shard`` call.
        recompute_ratio: Forwarded to ``megatron_internlm3_moe``.
        reshard_after_forward: Forwarded to ``megatron_internlm3_moe`` only;
            the root ``fully_shard`` call always passes ``False``.
    """
    on_rank0 = experts_fsdp_mesh.get_rank() == 0

    # Non-zero ranks never touch rank0_model, so hand the backbone None.
    backbone_rank0 = rank0_model.model if on_rank0 else None
    megatron_internlm3_moe(
        model.model,
        backbone_rank0,
        experts_fsdp_mesh,
        ep_mesh,
        mp_policy=mp_policy,
        recompute_ratio=recompute_ratio,
        reshard_after_forward=reshard_after_forward)

    # The full-model mapping is built after the backbone has been wrapped.
    rank0_map = map_rank0_modules(model, rank0_model) if on_rank0 else None
    init_fn = partial(
        lazy_init_megatron, rank0_map=rank0_map, dp_mesh=experts_fsdp_mesh)
    model.lm_head.apply(init_fn)

    from torch.distributed._composable.fsdp import fully_shard

    # Root wrap: reshard_after_forward is fixed to False regardless of the
    # argument passed by the caller.
    fully_shard(
        model,
        mesh=experts_fsdp_mesh,
        mp_policy=mp_policy,
        reshard_after_forward=False)

    first_layer = model.model.layers[0]
    model.set_modules_to_forward_prefetch([first_layer])
